import torch
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torchvision
from PIL import Image
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd


# Mapping from VOC 2007/2012 class name to integer label id; 0 is reserved
# for the background class, as torchvision detection models expect.
names = {'background': 0, 'person': 1, 'bird': 2, 'cat': 3, 'cow': 4, 'dog': 5, 'horse': 6, 'sheep': 7, 'aeroplane': 8,
         'bicycle': 9, 'boat': 10, 'bus': 11, 'car': 12, 'motorbike': 13, 'train': 14, 'bottle': 15, 'chair': 16, 'diningtable': 17,
         'pottedplant': 18, 'sofa': 19, 'tvmonitor': 20}


# Dataset class: subclass torch's Dataset, then instantiate it with the
# root directory and optional joint transforms.
class MarkDataset(torch.utils.data.Dataset):
    """Detection dataset over a VOC-layout directory.

    Expects ``root`` to contain ``JPEGImages/`` (the pictures) and
    ``Annotations/`` (one VOC-format XML file per picture).  Items are
    returned as ``(PIL.Image, target)`` pairs in the dict format consumed
    by the torchvision detection models / reference scripts
    (boxes, labels, image_id, area, iscrowd).
    """

    def __init__(self, root, transforms=None):
        """
        Args:
            root (str): dataset root containing JPEGImages/ and Annotations/.
            transforms (callable, optional): joint transform applied to
                ``(img, target)`` — the detection-reference style transform,
                not the image-only torchvision one.
        """
        self.root = root
        self.transforms = transforms
        # Sort both listings: os.listdir order is arbitrary, and image i
        # must line up with annotation i.  (The original comment claimed
        # sorting but the code did not sort, so pairs could be misaligned.)
        self.imgs = sorted(os.listdir(os.path.join(root, "JPEGImages")))
        self.bbox_xml = sorted(os.listdir(os.path.join(root, "Annotations")))

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``."""
        img_path = os.path.join(self.root, "JPEGImages", self.imgs[idx])
        bbox_xml_path = os.path.join(self.root, "Annotations", self.bbox_xml[idx])

        img = Image.open(img_path).convert("RGB")

        # VOC annotations are XML; collect every <object> element.
        dom = parse(bbox_xml_path)
        data = dom.documentElement
        objects = data.getElementsByTagName('object')

        boxes = []
        labels = []
        for object_ in objects:
            # <name> holds the class string; map it through the global
            # ``names`` table (background = 0).
            name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue
            labels.append(names[name])

            # getElementsByTagName returns a list even for the single
            # <bndbox> child, hence the [0].
            bndbox = object_.getElementsByTagName('bndbox')[0]
            # builtin float(), not np.float: the np.float alias was
            # deprecated in NumPy 1.20 and removed in 1.24.
            xmin = float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
            ymin = float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
            xmax = float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
            ymax = float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
            boxes.append([xmin, ymin, xmax, ymax])

        # Everything in target must be a tensor.  reshape(-1, 4) keeps the
        # (N, 4) shape even when the annotation has no objects, so the area
        # slicing below cannot fail on an empty file.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Multi-class labels, so as_tensor rather than torch.ones.
        labels = torch.as_tensor(labels, dtype=torch.int64)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Suppose all instances are not crowd.
        # NOTE(review): if the data has crowd/occluded objects this may
        # need revisiting.
        iscrowd = torch.zeros((len(objects),), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        # No target["masks"]: this is plain box detection, not segmentation.
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            # The joint transform also converts/augments target (bboxes) —
            # see the RandomHorizontalFlip example in
            # https://github.com/pytorch/vision/tree/master/references/detection
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)


# Alternative dataset class, written in the style of the official
# torchvision documentation.
class VOCCustomData(torchvision.datasets.vision.VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the custom VOC Dataset which includes
            directories Annotations and JPEGImages.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in
            the target and transforms it.
        transforms (callable, optional): joint ``(img, target)`` transform.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCCustomData, self).__init__(root, transforms, transform, target_transform)
        self.root = root
        self._transforms = transforms

        voc_root = self.root
        self.image_dir = os.path.join(voc_root, 'JPEGImages')
        self.annotation_dir = os.path.join(voc_root, 'Annotations')

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' Please verify the correct Dataset!')

        # splitext instead of split('.')[0] so stems containing extra dots
        # survive intact; sorted() makes the index -> file mapping
        # deterministic across runs and platforms.
        file_names = sorted(os.path.splitext(img)[0] for img in os.listdir(self.image_dir))
        # (An unused pandas DataFrame used to be built here, feeding only a
        # commented-out to_csv; removed as dead code.)

        self.images = [os.path.join(self.image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(self.annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert('RGB')

        target = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())

        # Repackage as the {'image_id', 'annotations'} dict shape the joint
        # transforms expect.
        target = dict(image_id=index, annotations=target['annotation'])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    def parse_voc_xml(self, node):
        """Recursively convert an ElementTree node into nested dicts.

        Repeated sibling tags collapse into a list; a single child stays
        scalar; leaf nodes map tag -> stripped text.
        """
        voc_dict = {}
        children = list(node)
        if children:
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict


# Data augmentation — really a mandatory preprocessing step: at minimum a
# ToTensor transform is required.
def get_transform(train):
    """Build the joint (img, target) transform pipeline for the dataset.

    Args:
        train (bool): kept for API symmetry with the torchvision reference
            scripts; training-time augmentation is currently disabled, so
            the same pipeline is returned either way.

    Returns:
        T.Compose: a composed transform converting the PIL image to a
        tensor.
    """
    # Transforms are collected as a list of callables and composed at the
    # end, reference-script style.
    pipeline = [T.ToTensor()]

    # Training augmentation is intentionally off for now.  To re-enable a
    # 50% horizontal flip (which also flips the target boxes), append
    # T.RandomHorizontalFlip(0.5) here when ``train`` is True.

    return T.Compose(pipeline)


def showbbox(model, img):
    """Run the detection model on one image and plot boxes for labels 1/2.

    Args:
        model: torchvision-style detection model; in eval mode it returns a
            list with one dict per image containing 'boxes', 'labels',
            'scores'.
        img: CHW float tensor with values in [0, 1].
    """
    model.eval()
    # Infer the device from the model's parameters instead of hard-coding
    # 'cuda' — the original crashed on CPU-only machines.
    device = next(model.parameters()).device
    with torch.no_grad():
        '''
        prediction looks like:
        [{'boxes': tensor([[1492.6672,  238.4670, 1765.5385,  315.0320],
        [ 887.1390,  256.8106, 1154.6687,  330.2953]], device='cuda:0'), 
        'labels': tensor([1, 1], device='cuda:0'), 
        'scores': tensor([1.0000, 1.0000], device='cuda:0')}]
        '''
        prediction = model([img.to(device)])

    print(prediction)

    img = img.permute(1, 2, 0)  # C,H,W -> H,W,C for plotting
    img = (img * 255).byte().data.cpu()  # float [0,1] -> uint8 [0,255]
    img = np.array(img)  # tensor -> ndarray (cv2 draws on ndarrays)

    for i in range(prediction[0]['boxes'].cpu().shape[0]):
        xmin = round(prediction[0]['boxes'][i][0].item())
        ymin = round(prediction[0]['boxes'][i][1].item())
        xmax = round(prediction[0]['boxes'][i][2].item())
        ymax = round(prediction[0]['boxes'][i][3].item())

        label = prediction[0]['labels'][i].item()

        # NOTE(review): these hard-coded captions come from an earlier
        # 2-class dataset; with the 21-class VOC ``names`` table, label 1
        # is 'person' and 2 is 'bird' — confirm which naming is intended.
        if label == 1:
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), thickness=2)
            cv2.putText(img, 'mark_type_1', (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0),
                        thickness=2)
        elif label == 2:
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), thickness=2)
            cv2.putText(img, 'mark_type_2', (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0),
                        thickness=2)

    plt.figure(figsize=(20, 15))
    plt.imshow(img)


def main(args):
    """Train Faster R-CNN on a VOC-layout dataset, evaluate after every
    epoch, save the model, and visualise predictions on one test image.

    Args:
        args: parsed CLI namespace; ``args.epochs`` and ``args.print_freq``
            are consumed here.
    """
    print(args, '\n')
    root = r'VOC2007'

    # Train on GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # 20 VOC object classes + 1 background class.
    num_classes = 21

    # Two dataset objects over the same files; the Subsets below pick the
    # actual train / test index ranges.
    dataset = MarkDataset(root, get_transform(train=False))
    dataset_test = MarkDataset(root, get_transform(train=False))

    # Random split.  The original used indices[0:40] for train and
    # indices[0:5] for test, so every test image was also a training image
    # and evaluation was meaningless — use disjoint slices instead.
    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices[0:40])
    dataset_test = torch.utils.data.Subset(dataset_test, indices[40:45])

    # Define training and validation data loaders.
    # NOTE: in a jupyter notebook num_workers must be 0 or the workers crash.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=12,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=12,
        collate_fn=utils.collate_fn)

    # Faster R-CNN with a pretrained ResNet-50 FPN backbone; the detection
    # head is trained from scratch for num_classes.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False, progress=True, num_classes=num_classes,
        pretrained_backbone=True)  # or get_object_detection_model(num_classes)

    model.to(device)

    # Optimise only the parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params, lr=0.0003,
                                momentum=0.9, weight_decay=0.0005)

    # Cosine-annealing learning-rate schedule with warm restarts.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2)

    # Honour the CLI options: the original parsed --epochs/--print-freq and
    # then ignored them, hard-coding 3 epochs and print_freq=1.
    num_epochs = args.epochs

    for epoch in range(num_epochs):
        # train_one_epoch (from engine.py) moves images and targets to
        # `device` itself and prints progress every print_freq iterations.
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        print_freq=args.print_freq)

        # Step the LR schedule once per epoch.
        lr_scheduler.step()

        # Validate on the held-out split every epoch.
        evaluate(model, data_loader_test, device=device)

        print('')
        print('==================================================')
        print('')

    print("That's it!")

    # os.path.join for portability (the raw 'models\model3.pkl' literal is
    # Windows-only) and make sure the output directory exists before saving.
    os.makedirs('models', exist_ok=True)
    model_path = os.path.join('models', 'model3.pkl')
    torch.save(model, model_path)

    # Round-trip through disk, then visualise one test sample.
    model = torch.load(model_path)
    model.to(device)

    img, _ = dataset_test[0]
    showbbox(model, img)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)

    # Only the options actually consumed downstream are exposed; the many
    # reference-script options (dataset paths, lr schedule, distributed
    # training, resume/checkpointing, ...) were never wired up here.
    parser.add_argument('--epochs', default=26, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--print-freq', default=20, type=int, help='print frequency')

    cli_args = parser.parse_args()

    # Hand the parsed command-line options to the training entry point.
    main(cli_args)
